# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("//bazel:psi.bzl", "psi_cc_library", "psi_cc_test")
load(":copts.bzl", "spiral_copts")

# Package-wide default: every target declared below is visible to all other
# packages unless it sets its own `visibility` attribute.
package(default_visibility = ["//visibility:public"])

# Protobuf schema target wrapping serializable.proto. Language-specific code is
# generated from it by :serializable_cc_proto.
# NOTE(review): `proto_library` is not brought in by any load() statement in
# this file, so it resolves to Bazel's built-in (or auto-loaded) rule — confirm
# this is still supported by the Bazel version the project pins.
proto_library(
    name = "serializable_proto",
    srcs = ["serializable.proto"],
)

# C++ bindings generated from :serializable_proto. C++ targets in this package
# depend on this target (not on the proto_library directly).
# NOTE(review): `cc_proto_library` is not brought in by any load() statement in
# this file either — confirm the built-in/auto-loaded rule is available on the
# Bazel version in use.
cc_proto_library(
    name = "serializable_cc_proto",
    deps = [":serializable_proto"],
)

# Header-only library (no srcs) exposing common.h. Its only dependency is
# Abseil's span type.
psi_cc_library(
    name = "common",
    hdrs = ["common.h"],
    deps = [
        "@abseil-cpp//absl/types:span",
    ],
)

# Library built from util.cc / util.h. Depends only on :params, so anything
# that :params needs must not depend back on :util (that would form a cycle).
psi_cc_library(
    name = "util",
    srcs = ["util.cc"],
    hdrs = ["util.h"],
    deps = [
        ":params",
    ],
)

# Library built from params.cc / params.h. Pulls in :common, the arithmetic
# targets from //psi/algorithm/spiral/arith (the package default target plus
# :ntt_table and :number_theory), and several yacl utilities (exception,
# int128, blake3 hashing, elapsed timer).
psi_cc_library(
    name = "params",
    srcs = ["params.cc"],
    hdrs = ["params.h"],
    deps = [
        ":common",
        "//psi/algorithm/spiral/arith",
        "//psi/algorithm/spiral/arith:ntt_table",
        "//psi/algorithm/spiral/arith:number_theory",
        "@yacl//yacl/base:exception",
        "@yacl//yacl/base:int128",
        "@yacl//yacl/crypto/hash:blake3",
        "@yacl//yacl/utils:elapsed_timer",
    ],
)

# Test target for :params, built from params_test.cc. Also links :common,
# :util and the arith NTT table target directly.
psi_cc_test(
    name = "params_test",
    srcs = ["params_test.cc"],
    deps = [
        ":common",
        ":params",
        ":util",
        "//psi/algorithm/spiral/arith:ntt_table",
    ],
)

# Library built from poly_matrix.cc / poly_matrix.h.
#
# Compiled with the extra flags returned by spiral_copts() (loaded from
# copts.bzl in this package). Uses the generated C++ proto bindings
# (:serializable_cc_proto), the arith parameter and NTT targets, Abseil
# strings/span, and yacl's exception, int128, rand, PRG and parallel helpers.
psi_cc_library(
    name = "poly_matrix",
    srcs = ["poly_matrix.cc"],
    hdrs = ["poly_matrix.h"],
    copts = spiral_copts(),
    deps = [
        ":params",
        ":serializable_cc_proto",
        ":util",
        "//psi/algorithm/spiral/arith:arith_params",
        "//psi/algorithm/spiral/arith:ntt",
        "@abseil-cpp//absl/strings",
        "@abseil-cpp//absl/types:span",
        "@yacl//yacl/base:exception",
        "@yacl//yacl/base:int128",
        "@yacl//yacl/crypto/rand",
        "@yacl//yacl/crypto/tools:prg",
        "@yacl//yacl/utils:parallel",
    ] + select({
        # On aarch64 only, add @sse2neon as an extra dependency.
        # Presumably this lets x86 SSE intrinsics used by the sources compile
        # on ARM — TODO confirm against poly_matrix.cc/.h.
        "@platforms//cpu:aarch64": [
            "@sse2neon",
        ],
        # Every other CPU: no additional dependencies.
        "//conditions:default": [],
    }),
)

# Library built from poly_matrix_utils.cc / poly_matrix_utils.h, layered on top
# of :poly_matrix and :discrete_gaussian. Compiled with spiral_copts(), the
# same flag set used for :poly_matrix.
psi_cc_library(
    name = "poly_matrix_utils",
    srcs = ["poly_matrix_utils.cc"],
    hdrs = ["poly_matrix_utils.h"],
    copts = spiral_copts(),
    deps = [
        ":discrete_gaussian",
        ":params",
        ":poly_matrix",
        ":util",
        "//psi/algorithm/spiral/arith:arith_params",
        "//psi/algorithm/spiral/arith:ntt",
        "@abseil-cpp//absl/strings",
        "@abseil-cpp//absl/types:span",
        "@yacl//yacl/base:exception",
        "@yacl//yacl/base:int128",
        "@yacl//yacl/crypto/rand",
        "@yacl//yacl/crypto/tools:prg",
        "@yacl//yacl/utils:parallel",
    ],
)

# Test target for :poly_matrix and :poly_matrix_utils, built from
# poly_matrix_test.cc with spiral_copts(). It is the only target in this
# package that depends on @seal.
# NOTE(review): presumably @seal is used as a reference implementation to
# cross-check results — confirm against the test source.
psi_cc_test(
    name = "poly_matrix_test",
    srcs = ["poly_matrix_test.cc"],
    copts = spiral_copts(),
    deps = [
        ":common",
        ":params",
        ":poly_matrix",
        ":poly_matrix_utils",
        ":util",
        "//psi/algorithm/spiral/arith:ntt_table",
        "@abseil-cpp//absl/types:span",
        "@seal",
        "@yacl//yacl/base:buffer",
        "@yacl//yacl/utils:elapsed_timer",
        "@yacl//yacl/utils:parallel",
    ],
)

# Library built from discrete_gaussian.cc / discrete_gaussian.h. Depends on
# :poly_matrix plus yacl's rand and PRG targets.
# NOTE(review): unlike :poly_matrix, this target does not set
# `copts = spiral_copts()` even though it depends on :poly_matrix — confirm
# that including poly_matrix.h without those flags is intended.
psi_cc_library(
    name = "discrete_gaussian",
    srcs = ["discrete_gaussian.cc"],
    hdrs = ["discrete_gaussian.h"],
    deps = [
        ":poly_matrix",
        "@yacl//yacl/crypto/rand",
        "@yacl//yacl/crypto/tools:prg",
    ],
)

# Test target for :discrete_gaussian, built from discrete_gaussian_test.cc.
# Links :params, :poly_matrix and :util directly in addition to the library
# under test, plus yacl's rand and PRG targets.
psi_cc_test(
    name = "discrete_gaussian_test",
    srcs = ["discrete_gaussian_test.cc"],
    deps = [
        ":discrete_gaussian",
        ":params",
        ":poly_matrix",
        ":util",
        "@yacl//yacl/crypto/rand",
        "@yacl//yacl/crypto/tools:prg",
    ],
)

# Library built from gadget.cc / gadget.h. Depends only on :params and
# :poly_matrix.
psi_cc_library(
    name = "gadget",
    srcs = ["gadget.cc"],
    hdrs = ["gadget.h"],
    deps = [
        ":params",
        ":poly_matrix",
    ],
)

# Test target for :gadget, built from gadget_test.cc. :poly_matrix is not
# listed here; it is reachable only transitively through :gadget.
psi_cc_test(
    name = "gadget_test",
    srcs = ["gadget_test.cc"],
    deps = [
        ":gadget",
        ":params",
        ":util",
    ],
)

# Header-only library (no srcs) exposing public_keys.h. Depends on :params,
# :poly_matrix and yacl's int128 target.
psi_cc_library(
    name = "public_keys",
    hdrs = ["public_keys.h"],
    deps = [
        ":params",
        ":poly_matrix",
        "@yacl//yacl/base:int128",
    ],
)

# Client-side library built from spiral_client.cc / spiral_client.h with
# spiral_copts(). Combines the local building blocks (:common,
# :discrete_gaussian, :gadget, :params, :poly_matrix, :poly_matrix_utils,
# :public_keys, :serialize) with the shared PIR interface targets
# (//psi/algorithm/pir_interface:index_pir and :pir_db) and yacl utilities.
psi_cc_library(
    name = "spiral_client",
    srcs = ["spiral_client.cc"],
    hdrs = ["spiral_client.h"],
    copts = spiral_copts(),
    deps = [
        ":common",
        ":discrete_gaussian",
        ":gadget",
        ":params",
        ":poly_matrix",
        ":poly_matrix_utils",
        ":public_keys",
        ":serialize",
        "//psi/algorithm/pir_interface:index_pir",
        "//psi/algorithm/pir_interface:pir_db",
        "@yacl//yacl/base:int128",
        "@yacl//yacl/crypto/rand",
        "@yacl//yacl/crypto/tools:prg",
        "@yacl//yacl/utils:elapsed_timer",
        "@yacl//yacl/utils:parallel",
    ],
)

# Test target for :spiral_client, built from spiral_client_test.cc.
# NOTE(review): no `copts = spiral_copts()` here, whereas :spiral_pir_test and
# :spiral_pir_load_test set it — confirm the difference is intentional.
psi_cc_test(
    name = "spiral_client_test",
    srcs = ["spiral_client_test.cc"],
    deps = [
        ":params",
        ":poly_matrix_utils",
        ":public_keys",
        ":serialize",
        ":spiral_client",
        ":util",
        "@abseil-cpp//absl/types:span",
        "@yacl//yacl/crypto/tools:prg",
    ],
)

# Server-side library built from spiral_server.cc / spiral_server.h with
# spiral_copts(). Its dependency list is the same as :spiral_client's plus
# :spiral_client itself and @protobuf, so the server library links the client
# library as well.
psi_cc_library(
    name = "spiral_server",
    srcs = ["spiral_server.cc"],
    hdrs = ["spiral_server.h"],
    copts = spiral_copts(),
    deps = [
        ":common",
        ":discrete_gaussian",
        ":gadget",
        ":params",
        ":poly_matrix",
        ":poly_matrix_utils",
        ":public_keys",
        ":serialize",
        ":spiral_client",
        "//psi/algorithm/pir_interface:index_pir",
        "//psi/algorithm/pir_interface:pir_db",
        "@protobuf",
        "@yacl//yacl/base:int128",
        "@yacl//yacl/crypto/rand",
        "@yacl//yacl/crypto/tools:prg",
        "@yacl//yacl/utils:elapsed_timer",
        "@yacl//yacl/utils:parallel",
    ],
)

# Test target for :spiral_server, built from spiral_server_test.cc. Links both
# :spiral_client and :spiral_server.
# NOTE(review): no `copts = spiral_copts()` here, whereas :spiral_pir_test and
# :spiral_pir_load_test set it — confirm the difference is intentional.
psi_cc_test(
    name = "spiral_server_test",
    srcs = ["spiral_server_test.cc"],
    deps = [
        ":gadget",
        ":params",
        ":poly_matrix_utils",
        ":public_keys",
        ":spiral_client",
        ":spiral_server",
        ":util",
        "@abseil-cpp//absl/types:span",
        "@yacl//yacl/crypto/tools:prg",
    ],
)

# Test target built from spiral_pir_test.cc with spiral_copts(). Links both
# :spiral_client and :spiral_server, plus yacl's elapsed timer in addition to
# the dependencies used by :spiral_server_test.
psi_cc_test(
    name = "spiral_pir_test",
    srcs = ["spiral_pir_test.cc"],
    copts = spiral_copts(),
    deps = [
        ":gadget",
        ":params",
        ":poly_matrix_utils",
        ":public_keys",
        ":spiral_client",
        ":spiral_server",
        ":util",
        "@abseil-cpp//absl/types:span",
        "@yacl//yacl/crypto/tools:prg",
        "@yacl//yacl/utils:elapsed_timer",
    ],
)

# Test target built from spiral_pir_load_test.cc with spiral_copts(). Same
# dependency list as :spiral_pir_test, with @protobuf added as a direct
# dependency.
psi_cc_test(
    name = "spiral_pir_load_test",
    srcs = ["spiral_pir_load_test.cc"],
    copts = spiral_copts(),
    deps = [
        ":gadget",
        ":params",
        ":poly_matrix_utils",
        ":public_keys",
        ":spiral_client",
        ":spiral_server",
        ":util",
        "@abseil-cpp//absl/types:span",
        "@protobuf",
        "@yacl//yacl/crypto/tools:prg",
        "@yacl//yacl/utils:elapsed_timer",
    ],
)

# Library built from serialize.cc / serialize.h. Depends on the generated C++
# proto bindings (:serializable_cc_proto) together with :common, :params,
# :poly_matrix, :public_keys, :util, the arith NTT table target, Abseil span
# and yacl's buffer type. :spiral_client and :spiral_server both depend on it.
psi_cc_library(
    name = "serialize",
    srcs = ["serialize.cc"],
    hdrs = ["serialize.h"],
    deps = [
        ":common",
        ":params",
        ":poly_matrix",
        ":public_keys",
        ":serializable_cc_proto",
        ":util",
        "//psi/algorithm/spiral/arith:ntt_table",
        "@abseil-cpp//absl/types:span",
        "@yacl//yacl/base:buffer",
    ],
)
